{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a93f115",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._scalers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c704dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83d112c7-18f8-4f20-acad-34e6de54cebf",
   "metadata": {},
   "source": [
    "# TemporalNorm\n",
    "\n",
    "> Temporal normalization has proven essential in neural forecasting tasks, as it enables the network's non-linearities to express themselves. Forecasting scaling methods focus on the temporal dimension, where most of the variance dwells, in contrast to other deep learning techniques such as `BatchNorm`, which normalizes across the batch and temporal dimensions, and `LayerNorm`, which normalizes across the feature dimension. Currently we support the following techniques: `identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`, `revin`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fee5e60b-f53b-44ff-9ace-1f5def7b601d",
   "metadata": {},
   "source": [
    "## References"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9211dd2-99a4-4d67-90cb-bb1f7851685e",
   "metadata": {},
   "source": [
    "* [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)\n",
    "* [Taesung Kim and Jinhee Kim and Yunwon Tae and Cheonbok Park and Jang-Ho Choi and Jaegul Choo. \"Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift\". ICLR 2022.](https://openreview.net/pdf?id=cGDAkQo1C0p)\n",
    "* [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9319296d",
   "metadata": {},
   "source": [
    "![Figure 1. Illustration of temporal normalization (left), layer normalization (center) and batch normalization (right). The entries in green show the components used to compute the normalizing statistics.](imgs_models/temporal_norm.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df2cc55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f08562b-88d8-4e92-aeeb-bc9bc4c61ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import show_doc\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5201e067-f7c0-4ca3-89a7-d879001b1908",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "plt.rcParams[\"axes.grid\"]=True\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams[\"figure.figsize\"] = (4,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef461e9c",
   "metadata": {},
   "source": [
    "# 1. Auxiliary Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12a249a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def masked_median(x, mask, dim=-1, keepdim=True):\n",
    "    \"\"\" Masked Median\n",
    "\n",
    "    Compute the median of tensor `x` along `dim`, ignoring values where\n",
    "    `mask` is False. `x` and `mask` need to be broadcastable.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor to compute median of along `dim` dimension.<br>\n",
    "    `mask`: torch Tensor bool with same shape as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `dim` (int, optional): Dimension to take median of. Defaults to -1.<br>\n",
    "    `keepdim` (bool, optional): Keep dimension of `x` or not. Defaults to True.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_median`: torch.Tensor with the masked median of `x` along `dim`.\n",
    "    \"\"\"\n",
    "    x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
    "    x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)\n",
    "    x_median = torch.nan_to_num(x_median, nan=0.0)\n",
    "    return x_median\n",
    "\n",
    "def masked_mean(x, mask, dim=-1, keepdim=True):\n",
    "    \"\"\" Masked Mean\n",
    "\n",
    "    Compute the mean of tensor `x` along `dim`, ignoring values where\n",
    "    `mask` is False. `x` and `mask` need to be broadcastable.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor to compute mean of along `dim` dimension.<br>\n",
    "    `mask`: torch Tensor bool with same shape as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `dim` (int, optional): Dimension to take mean of. Defaults to -1.<br>\n",
    "    `keepdim` (bool, optional): Keep dimension of `x` or not. Defaults to True.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_mean`: torch.Tensor with the masked mean of `x` along `dim`.\n",
    "    \"\"\"\n",
    "    x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
    "    x_mean = x_nan.nanmean(dim=dim, keepdim=keepdim)\n",
    "    x_mean = torch.nan_to_num(x_mean, nan=0.0)\n",
    "    return x_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49d2e338",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(masked_median, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "300e1b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(masked_mean, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7a486a2",
   "metadata": {},
   "source": [
    "# 2. Scalers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c76dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def minmax_statistics(x, mask, eps=1e-6, dim=-1):\n",
    "    \"\"\" MinMax Scaler\n",
    "\n",
    "    Scales temporal features so that their values lie within the\n",
    "    [0,1] range. This transformation is often used as an alternative\n",
    "    to the standard scaler. The scaled features are obtained as:\n",
    "\n",
    "    $$\n",
    "    \\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\mathrm{min}({\\mathbf{x}})_{[B,1,C]})/\n",
    "        (\\mathrm{max}({\\mathbf{x}})_{[B,1,C]}- \\mathrm{min}({\\mathbf{x}})_{[B,1,C]})\n",
    "    $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor input tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension over which to compute min and max. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_min`: torch.Tensor with the masked minimum of `x` along `dim`.<br>\n",
    "    `x_range`: torch.Tensor with the masked range (max - min) of `x` along `dim`,\n",
    "               protected against division by zero.\n",
    "    \"\"\"\n",
    "    mask = mask.clone()\n",
    "    mask[mask==0] = torch.inf\n",
    "    mask[mask==1] = 0\n",
    "    x_max = torch.max(torch.nan_to_num(x-mask,nan=-torch.inf), dim=dim, keepdim=True)[0]\n",
    "    x_min = torch.min(torch.nan_to_num(x+mask,nan=torch.inf), dim=dim, keepdim=True)[0]\n",
    "    x_max = x_max.type(x.dtype)\n",
    "    x_min = x_min.type(x.dtype)\n",
    "\n",
    "    # x_range and prevent division by zero\n",
    "    x_range = x_max - x_min\n",
    "    x_range[x_range==0] = 1.0\n",
    "    x_range = x_range + eps\n",
    "    return x_min, x_range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39fa429b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def minmax_scaler(x, x_min, x_range):\n",
    "    return (x - x_min) / x_range\n",
    "\n",
    "def inv_minmax_scaler(z, x_min, x_range):\n",
    "    return z * x_range + x_min"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99ea1aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(minmax_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334b3d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def minmax1_statistics(x, mask, eps=1e-6, dim=-1):\n",
    "    \"\"\" MinMax1 Scaler\n",
    "\n",
    "    Scales temporal features so that their values lie within the\n",
    "    [-1,1] range. This transformation is often used as an alternative\n",
    "    to the standard scaler or the classic MinMax scaler.\n",
    "    The scaled features are obtained as:\n",
    "\n",
    "    $$\\mathbf{z} = 2 (\\mathbf{x}_{[B,T,C]}-\\mathrm{min}({\\mathbf{x}})_{[B,1,C]})/ (\\mathrm{max}({\\mathbf{x}})_{[B,1,C]}- \\mathrm{min}({\\mathbf{x}})_{[B,1,C]})-1$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor input tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension over which to compute min and max. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_min`: torch.Tensor with the masked minimum of `x` along `dim`.<br>\n",
    "    `x_range`: torch.Tensor with the masked range (max - min) of `x` along `dim`,\n",
    "               protected against division by zero.\n",
    "    \"\"\"\n",
    "    # Mask values (set masked to -inf or +inf)\n",
    "    mask = mask.clone()\n",
    "    mask[mask==0] = torch.inf\n",
    "    mask[mask==1] = 0\n",
    "    x_max = torch.max(torch.nan_to_num(x-mask,nan=-torch.inf), dim=dim, keepdim=True)[0]\n",
    "    x_min = torch.min(torch.nan_to_num(x+mask,nan=torch.inf), dim=dim, keepdim=True)[0]\n",
    "    x_max = x_max.type(x.dtype)\n",
    "    x_min = x_min.type(x.dtype)\n",
    "    \n",
    "    # x_range and prevent division by zero\n",
    "    x_range = x_max - x_min\n",
    "    x_range[x_range==0] = 1.0\n",
    "    x_range = x_range + eps\n",
    "    return x_min, x_range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a19ed5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def minmax1_scaler(x, x_min, x_range):\n",
    "    x = (x - x_min) / x_range\n",
    "    z = x * (2) - 1\n",
    "    return z\n",
    "\n",
    "def inv_minmax1_scaler(z, x_min, x_range):\n",
    "    z = (z + 1) / 2\n",
    "    return z * x_range + x_min"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ccb77b",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(minmax1_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c187a8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def std_statistics(x, mask, dim=-1, eps=1e-6):\n",
    "    \"\"\" Standard Scaler\n",
    "\n",
    "    Standardizes features by removing the mean and scaling\n",
    "    to unit variance along the `dim` dimension.\n",
    "\n",
    "    For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
    "\n",
    "    $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\bar{\\mathbf{x}}_{[B,1,C]})/\\hat{\\sigma}_{[B,1,C]}$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension over which to compute mean and std. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_means`: torch.Tensor with the masked mean of `x` along `dim`.<br>\n",
    "    `x_stds`: torch.Tensor with the masked standard deviation of `x` along `dim`,\n",
    "              protected against division by zero.\n",
    "    \"\"\"\n",
    "    x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
    "    x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim))\n",
    "\n",
    "    # Protect against division by zero\n",
    "    x_stds[x_stds==0] = 1.0\n",
    "    x_stds = x_stds + eps\n",
    "    return x_means, x_stds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17f90821",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def std_scaler(x, x_means, x_stds):\n",
    "    return (x - x_means) / x_stds\n",
    "\n",
    "def inv_std_scaler(z, x_mean, x_std):\n",
    "    return (z * x_std) + x_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e077730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(std_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c22a041",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def robust_statistics(x, mask, dim=-1, eps=1e-6):\n",
    "    \"\"\" Robust Median Scaler\n",
    "\n",
    "    Standardizes features by removing the median and scaling\n",
    "    with the median absolute deviation (mad), a robust estimator of dispersion.\n",
    "    This scaler is particularly useful with noisy data where outliers can\n",
    "    heavily distort the sample mean and variance.\n",
    "    In these scenarios the median and mad give better results.\n",
    "\n",
    "    For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
    "\n",
    "    $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\textrm{median}(\\mathbf{x})_{[B,1,C]})/\\\\textrm{mad}(\\mathbf{x})_{[B,1,C]}$$\n",
    "\n",
    "    $$\\\\textrm{mad}(\\mathbf{x}) = \\\\textrm{median}(|\\mathbf{x} - \\\\textrm{median}(\\mathbf{x})|)$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor input tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension over which to compute median and mad. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_median`: torch.Tensor with the masked median of `x` along `dim`.<br>\n",
    "    `x_mad`: torch.Tensor with the masked median absolute deviation of `x` along `dim`,\n",
    "             protected against division by zero.\n",
    "    \"\"\"\n",
    "    x_median = masked_median(x=x, mask=mask, dim=dim)\n",
    "    x_mad = masked_median(x=torch.abs(x-x_median), mask=mask, dim=dim)\n",
    "\n",
    "    # Protect x_mad=0 values\n",
    "    # Assuming normality and relationship between mad and std\n",
    "    x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
    "    x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim))  \n",
    "    x_mad_aux = x_stds * 0.6744897501960817\n",
    "    x_mad = x_mad * (x_mad>0) + x_mad_aux * (x_mad==0)\n",
    "    \n",
    "    # Protect against division by zero\n",
    "    x_mad[x_mad==0] = 1.0\n",
    "    x_mad = x_mad + eps\n",
    "    return x_median, x_mad"
   ]
  },
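  {
   "cell_type": "markdown",
   "id": "7c2e9b14",
   "metadata": {},
   "source": [
    "A short note on the constant `0.6744897501960817` used in `robust_statistics`, offered as a sketch under a Gaussian assumption rather than a statement about every input. When the median absolute deviation is zero, the code falls back to a rescaled standard deviation. For $X \\sim \\mathcal{N}(\\mu, \\sigma^2)$, the median absolute deviation $m$ (notation introduced here) satisfies:\n",
    "\n",
    "$$\n",
    "P(|X-\\mu| \\leq m) = \\tfrac{1}{2}\n",
    "\\;\\Longrightarrow\\; 2\\,\\Phi(m/\\sigma) - 1 = \\tfrac{1}{2}\n",
    "\\;\\Longrightarrow\\; m = \\sigma\\,\\Phi^{-1}(3/4) \\approx 0.6745\\,\\sigma\n",
    "$$\n",
    "\n",
    "Here $\\Phi$ is the standard normal CDF. The fallback `x_stds * 0.6744897501960817` is therefore the value the median absolute deviation would take if the unmasked data were Gaussian with the observed standard deviation."
   ]
  },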
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33f3cf28",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def robust_scaler(x, x_median, x_mad):\n",
    "    return (x - x_median) / x_mad\n",
    "\n",
    "def inv_robust_scaler(z, x_median, x_mad):\n",
    "    return z * x_mad + x_median"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7355a5f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(robust_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8879b00b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def invariant_statistics(x, mask, dim=-1, eps=1e-6):\n",
    "    \"\"\" Invariant Median Scaler\n",
    "\n",
    "    Standardizes features by removing the median and scaling\n",
    "    with the median absolute deviation (mad), a robust estimator of dispersion.\n",
    "    Additionally, it applies the arcsinh transformation to the scaled values.\n",
    "\n",
    "    For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
    "\n",
    "    $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\textrm{median}(\\mathbf{x})_{[B,1,C]})/\\\\textrm{mad}(\\mathbf{x})_{[B,1,C]}$$\n",
    "\n",
    "    $$\\mathbf{z} = \\\\textrm{arcsinh}(\\mathbf{z})$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor input tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Mask should not be all False in any column of\n",
    "            dimension dim to avoid NaNs from zero division.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension over which to compute median and mad. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_median`: torch.Tensor with the masked median of `x` along `dim`.<br>\n",
    "    `x_mad`: torch.Tensor with the masked median absolute deviation of `x` along `dim`,\n",
    "             protected against division by zero.\n",
    "    \"\"\"\n",
    "    x_median = masked_median(x=x, mask=mask, dim=dim)\n",
    "    x_mad = masked_median(x=torch.abs(x-x_median), mask=mask, dim=dim)\n",
    "\n",
    "    # Protect x_mad=0 values\n",
    "    # Assuming normality and relationship between mad and std\n",
    "    x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
    "    x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim))        \n",
    "    x_mad_aux = x_stds * 0.6744897501960817\n",
    "    x_mad = x_mad * (x_mad>0) + x_mad_aux * (x_mad==0)\n",
    "\n",
    "    # Protect against division by zero\n",
    "    x_mad[x_mad==0] = 1.0\n",
    "    x_mad = x_mad + eps\n",
    "    return x_median, x_mad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cca2bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def invariant_scaler(x, x_median, x_mad):\n",
    "    return torch.arcsinh((x - x_median) / x_mad)\n",
    "\n",
    "def inv_invariant_scaler(z, x_median, x_mad):\n",
    "    return torch.sinh(z) * x_mad + x_median"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4b1b313",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(invariant_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50ba1916",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def identity_statistics(x, mask, dim=-1, eps=1e-6):\n",
    "    \"\"\" Identity Scaler\n",
    "\n",
    "    A placeholder identity scaler that leaves `x` unchanged regardless of its arguments.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: torch.Tensor input tensor.<br>\n",
    "    `mask`: torch Tensor bool, same dimension as `x`, True where `x` is valid and False\n",
    "            where `x` should be masked. Unused by this scaler.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Unused by this scaler. Defaults to 1e-6.<br>\n",
    "    `dim` (int, optional): Dimension collapsed to size 1 in the returned statistics. Defaults to -1.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x_shift`: torch.Tensor of zeros with the shape of `x`, except size 1 along `dim`.<br>\n",
    "    `x_scale`: torch.Tensor of ones with the shape of `x`, except size 1 along `dim`.\n",
    "    \"\"\"\n",
    "    # Collapse dim dimension\n",
    "    shape = list(x.shape)\n",
    "    shape[dim] = 1\n",
    "\n",
    "    # Create the statistics on the same device as x\n",
    "    x_shift = torch.zeros(shape, device=x.device)\n",
    "    x_scale = torch.ones(shape, device=x.device)\n",
    "\n",
    "    return x_shift, x_scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d7b313e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def identity_scaler(x, x_shift, x_scale):\n",
    "    return x\n",
    "\n",
    "def inv_identity_scaler(z, x_shift, x_scale):\n",
    "    return z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56ae8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(identity_statistics, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e87e828c",
   "metadata": {},
   "source": [
    "# 3. TemporalNorm Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb48423b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TemporalNorm(nn.Module):\n",
    "    \"\"\" Temporal Normalization\n",
    "\n",
    "    Standardization of the features is a common requirement for many\n",
    "    machine learning estimators, and it is commonly achieved by removing\n",
    "    the level and scaling its variance. The `TemporalNorm` module applies\n",
    "    temporal normalization over the batch of inputs as defined by the type of scaler.\n",
    "\n",
    "    $$\\mathbf{z}_{[B,T,C]} = \\\\textrm{Scaler}(\\mathbf{x}_{[B,T,C]})$$\n",
    "\n",
    "    If `scaler_type` is `revin`, learnable normalization parameters are added on top of\n",
    "    the usual normalization technique; the parameters are learned through scale decouple\n",
    "    global skip connections. The technique is available for point and probabilistic outputs.\n",
    "\n",
    "    $$\\mathbf{\\hat{z}}_{[B,T,C]} = \\\\boldsymbol{\\hat{\\\\gamma}}_{[1,1,C]} \\mathbf{z}_{[B,T,C]} +\\\\boldsymbol{\\hat{\\\\beta}}_{[1,1,C]}$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `scaler_type`: str, defines the type of scaler used by TemporalNorm. Available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`, `revin`].<br>\n",
    "    `dim` (int, optional): Dimension over which to compute scale and shift. Defaults to -1.<br>\n",
    "    `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.<br>\n",
    "    `num_features`: int=None, for RevIN-like learnable affine parameters initialization.<br>\n",
    "\n",
    "    **References**<br>\n",
    "    - [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, scaler_type='robust', dim=-1, eps=1e-6, num_features=None):\n",
    "        super().__init__()\n",
    "        compute_statistics = {None: identity_statistics,\n",
    "                              'identity': identity_statistics,\n",
    "                              'standard': std_statistics,\n",
    "                              'revin': std_statistics,\n",
    "                              'robust': robust_statistics,\n",
    "                              'minmax': minmax_statistics,\n",
    "                              'minmax1': minmax1_statistics,\n",
    "                              'invariant': invariant_statistics,}\n",
    "        scalers = {None: identity_scaler,\n",
    "                   'identity': identity_scaler,\n",
    "                   'standard': std_scaler,\n",
    "                   'revin': std_scaler,\n",
    "                   'robust': robust_scaler,\n",
    "                   'minmax': minmax_scaler,\n",
    "                   'minmax1': minmax1_scaler,\n",
    "                   'invariant':invariant_scaler,}\n",
    "        inverse_scalers = {None: inv_identity_scaler,\n",
    "                    'identity': inv_identity_scaler,\n",
    "                    'standard': inv_std_scaler,\n",
    "                    'revin': inv_std_scaler,\n",
    "                    'robust': inv_robust_scaler,\n",
    "                    'minmax': inv_minmax_scaler,\n",
    "                    'minmax1': inv_minmax1_scaler,\n",
    "                    'invariant': inv_invariant_scaler,}\n",
    "        assert (scaler_type in scalers.keys()), f'{scaler_type} not defined'\n",
    "        if (scaler_type=='revin') and (num_features is None):\n",
    "            raise Exception('You must pass num_features for RevIN scaler.')\n",
    "\n",
    "        self.compute_statistics = compute_statistics[scaler_type]\n",
    "        self.scaler = scalers[scaler_type]\n",
    "        self.inverse_scaler = inverse_scalers[scaler_type]\n",
    "        self.scaler_type = scaler_type\n",
    "        self.dim = dim\n",
    "        self.eps = eps\n",
    "\n",
    "        if (scaler_type=='revin'):\n",
    "            self._init_params(num_features=num_features)\n",
    "\n",
    "    def _init_params(self, num_features):\n",
    "        # Initialize RevIN scaler params to broadcast:\n",
    "        if self.dim==1: # [B,T,C]  [1,1,C]\n",
    "            self.revin_bias = nn.Parameter(torch.zeros(1,1,num_features))\n",
    "            self.revin_weight = nn.Parameter(torch.ones(1,1,num_features))\n",
    "        elif self.dim==-1: # [B,C,T]  [1,C,1]\n",
    "            self.revin_bias = nn.Parameter(torch.zeros(1,num_features,1))\n",
    "            self.revin_weight = nn.Parameter(torch.ones(1,num_features,1))\n",
    "\n",
    "    #@torch.no_grad()\n",
    "    def transform(self, x, mask):\n",
    "        \"\"\" Center and scale the data.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `x`: torch.Tensor shape [batch, time, channels].<br>\n",
    "        `mask`: torch Tensor bool, broadcastable with `x`, True where `x` is valid and False\n",
    "                where `x` should be masked. Mask should not be all False in any column of\n",
    "                dimension dim to avoid NaNs from zero division.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `z`: torch.Tensor same shape as `x`, except scaled.\n",
    "        \"\"\"\n",
    "        x_shift, x_scale = self.compute_statistics(x=x, mask=mask, dim=self.dim, eps=self.eps)\n",
    "        self.x_shift = x_shift\n",
    "        self.x_scale = x_scale\n",
    "\n",
    "        # Original Revin performs this operation\n",
    "        # z = self.revin_weight * z\n",
    "        # z = z + self.revin_bias\n",
    "        # However this is only valid for point forecast not for\n",
    "        # distribution's scale decouple technique.\n",
    "        if self.scaler_type=='revin':\n",
    "            self.x_shift = self.x_shift + self.revin_bias\n",
    "            self.x_scale = self.x_scale * (torch.relu(self.revin_weight) + self.eps)\n",
    "\n",
    "        z = self.scaler(x, x_shift, x_scale)\n",
    "        return z\n",
    "\n",
    "    #@torch.no_grad()\n",
    "    def inverse_transform(self, z, x_shift=None, x_scale=None):\n",
    "        \"\"\" Scale back the data to the original representation.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `z`: torch.Tensor shape [batch, time, channels], scaled.<br>\n",
    "        `x_shift` (torch.Tensor, optional): Shift statistics. Defaults to the ones\n",
    "                stored by the last `transform` call.<br>\n",
    "        `x_scale` (torch.Tensor, optional): Scale statistics. Defaults to the ones\n",
    "                stored by the last `transform` call.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `x`: torch.Tensor original data.\n",
    "        \"\"\"\n",
    "\n",
    "        if x_shift is None:\n",
    "            x_shift = self.x_shift\n",
    "        if x_scale is None:\n",
    "            x_scale = self.x_scale\n",
    "\n",
    "        # Original Revin performs this operation\n",
    "        # z = z - self.revin_bias\n",
    "        # z = (z / (self.revin_weight + self.eps))\n",
    "        # However this is only valid for point forecast not for\n",
    "        # distribution's scale decouple technique.\n",
    "\n",
    "        x = self.inverse_scaler(z, x_shift, x_scale)\n",
    "        return x\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The gradients are obtained from BaseWindows/BaseRecurrent forwards.\n",
    "        pass"
   ]
  },
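  {
   "cell_type": "markdown",
   "id": "d05a3f86",
   "metadata": {},
   "source": [
    "A hedged reading of the `revin` branch in `transform` and `inverse_transform`, written out as equations. The symbols $\\mu$, $\\sigma$, $\\gamma$, $\\beta$ are notation introduced here for `x_shift`, `x_scale`, `revin_weight` and `revin_bias` as returned by `std_statistics` and initialized in `_init_params`. The forward pass scales with the plain statistics, while the statistics stored for the inverse pass carry the learnable parameters:\n",
    "\n",
    "$$\n",
    "\\mathbf{z} = (\\mathbf{x} - \\mu)/\\sigma,\n",
    "\\qquad\n",
    "\\hat{\\mathbf{x}} = \\mathbf{z}\\,\\sigma\\,(\\mathrm{ReLU}(\\gamma) + \\epsilon) + (\\mu + \\beta)\n",
    "$$\n",
    "\n",
    "At initialization $\\gamma = 1$ and $\\beta = 0$, so a round trip gives:\n",
    "\n",
    "$$\n",
    "\\hat{\\mathbf{x}} = (\\mathbf{x} - \\mu)(1 + \\epsilon) + \\mu = \\mathbf{x} + \\epsilon\\,(\\mathbf{x} - \\mu)\n",
    "$$\n",
    "\n",
    "Under this reading, `revin` recovers the input only up to a term of order $\\epsilon$, which may be one reason the validation of the scalers in this notebook compares with an absolute tolerance instead of exact equality."
   ]
  },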
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d7a892",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TemporalNorm, name='TemporalNorm', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3490b4a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TemporalNorm.transform, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df49d4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TemporalNorm.inverse_transform, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e2968e0",
   "metadata": {},
   "source": [
    "# Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99722125",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7fef46f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Declare synthetic batch to normalize\n",
    "x1 = 10**0 * np.arange(36)[:, None]\n",
    "x2 = 10**1 * np.arange(36)[:, None]\n",
    "\n",
    "np_x = np.concatenate([x1, x2], axis=1)\n",
    "np_x = np.repeat(np_x[None, :,:], repeats=2, axis=0)\n",
    "np_x[0,:,:] = np_x[0,:,:] + 100\n",
    "\n",
    "np_mask = np.ones(np_x.shape)\n",
    "np_mask[:, -12:, :] = 0\n",
    "\n",
    "print(f'x.shape [batch, time, features]={np_x.shape}')\n",
    "print(f'mask.shape [batch, time, features]={np_mask.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da1f93ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate scalers\n",
    "x = 1.0*torch.tensor(np_x)\n",
    "mask = torch.tensor(np_mask)\n",
    "scaler = TemporalNorm(scaler_type='standard', dim=1)\n",
    "x_scaled = scaler.transform(x=x, mask=mask)\n",
    "x_recovered = scaler.inverse_transform(x_scaled)\n",
    "\n",
    "plt.plot(x[0,:,0], label='x1', color='#78ACA8')\n",
    "plt.plot(x[0,:,1], label='x2',  color='#E3A39A')\n",
    "plt.title('Before TemporalNorm')\n",
    "plt.xlabel('Time')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "plt.plot(x_scaled[0,:,0], label='x1', color='#78ACA8')\n",
    "plt.plot(x_scaled[0,:,1]+0.1, label='x2+0.1', color='#E3A39A')\n",
    "plt.title(f'TemporalNorm \\'{scaler.scaler_type}\\' ')\n",
    "plt.xlabel('Time')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "plt.plot(x_recovered[0,:,0], label='x1', color='#78ACA8')\n",
    "plt.plot(x_recovered[0,:,1], label='x2', color='#E3A39A')\n",
    "plt.title('Recovered')\n",
    "plt.xlabel('Time')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
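  {
   "cell_type": "markdown",
   "id": "e91b47c0",
   "metadata": {},
   "source": [
    "A worked check of the `standard` scaler on the synthetic batch above, given as a sketch. It assumes the statistics are computed over the 24 unmasked steps of the first batch entry, and it ignores the small `eps` offset. The time index $t$ is notation introduced here. The two features are $x_{1,t} = 100 + t$ and $x_{2,t} = 100 + 10t$ for $t = 0, \\dots, 23$, so:\n",
    "\n",
    "$$\n",
    "\\bar{x}_1 = 111.5, \\quad \\hat{\\sigma}_1 = \\sqrt{(24^2 - 1)/12} \\approx 6.92,\n",
    "\\qquad\n",
    "\\bar{x}_2 = 215, \\quad \\hat{\\sigma}_2 = 10\\,\\hat{\\sigma}_1 \\approx 69.22\n",
    "$$\n",
    "\n",
    "$$\n",
    "z_{1,t} = \\frac{(100 + t) - 111.5}{6.92} = \\frac{t - 11.5}{6.92} = \\frac{(100 + 10t) - 215}{69.22} = z_{2,t}\n",
    "$$\n",
    "\n",
    "Both features map to the same scaled series, which is presumably why the second plot offsets `x2` by 0.1 to keep the two lines distinguishable. The 12 masked steps are scaled with the same statistics, so the scaled series extends to roughly $(35 - 11.5)/6.92 \\approx 3.39$ at the last step."
   ]
  },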
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aa6920e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Validate scalers\n",
    "for scaler_type in [None, 'identity', 'standard', 'robust', 'minmax', 'minmax1', 'invariant', 'revin']:\n",
    "    x = 1.0*torch.tensor(np_x)\n",
    "    mask = torch.tensor(np_mask)\n",
    "    scaler = TemporalNorm(scaler_type=scaler_type, dim=1, num_features=np_x.shape[-1])\n",
    "    x_scaled = scaler.transform(x=x, mask=mask)\n",
    "    x_recovered = scaler.inverse_transform(x_scaled)\n",
    "    assert torch.allclose(x, x_recovered, atol=1e-3), f'Recovered data is not the same as original with {scaler_type}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e3dbfc-2677-4d1f-85bc-de6343196045",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import pandas as pd\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NHITS\n",
    "from neuralforecast.utils import AirPassengersDF as Y_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28e5f23d-9a64-4d77-8a27-55fcc765d0b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit test for masked predict filtering\n",
    "model = NHITS(h=12,\n",
    "              input_size=12*2,\n",
    "              max_steps=1,\n",
    "              windows_batch_size=None,\n",
    "              n_freq_downsample=[1,1,1],\n",
    "              scaler_type='minmax')\n",
    "\n",
    "nf = NeuralForecast(models=[model], freq='M')\n",
    "nf.fit(df=Y_df)\n",
    "Y_hat = nf.predict(df=Y_df)\n",
    "assert pd.isnull(Y_hat).sum().sum() == 0, 'Predictions should not have NaNs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "383f05b4-e921-4fa6-b2a1-65105b5eebd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NHITS, RNN\n",
    "from neuralforecast.losses.pytorch import HuberLoss, GMM\n",
    "from neuralforecast.utils import AirPassengersPanel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2095d2-74d4-4b94-bee3-c049aac8494d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Unit test for RevIN and its compatibility with the scale decoupling of distribution losses\n",
    "Y_df = AirPassengersPanel\n",
    "\n",
    "# BaseWindow model: test RevIN's dynamic dimensionality with a GMM loss and one hist_exog variable\n",
    "model = NHITS(h=12,\n",
    "              input_size=24,\n",
    "              loss=GMM(n_components=10, level=[90]),\n",
    "              hist_exog_list=['y_[lag12]'],\n",
    "              max_steps=1,\n",
    "              early_stop_patience_steps=10,\n",
    "              val_check_steps=50,\n",
    "              scaler_type='revin',\n",
    "              learning_rate=1e-3)\n",
    "nf = NeuralForecast(models=[model], freq='MS')\n",
    "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
    "\n",
    "# BaseWindow model: test RevIN's dynamic dimensionality with a Huber loss and two hist_exog variables\n",
    "model = NHITS(h=12,\n",
    "              input_size=24,\n",
    "              loss=HuberLoss(),\n",
    "              hist_exog_list=['trend', 'y_[lag12]'],\n",
    "              max_steps=1,\n",
    "              early_stop_patience_steps=10,\n",
    "              val_check_steps=50,\n",
    "              scaler_type='revin',\n",
    "              learning_rate=1e-3)\n",
    "nf = NeuralForecast(models=[model], freq='MS')\n",
    "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
    "\n",
    "# BaseRecurrent model: test RevIN's dynamic dimensionality with a GMM loss and two hist_exog variables\n",
    "model = RNN(h=12,\n",
    "            input_size=24,\n",
    "            loss=GMM(n_components=10, level=[90]),\n",
    "            hist_exog_list=['trend', 'y_[lag12]'],\n",
    "            max_steps=1,\n",
    "            early_stop_patience_steps=10,\n",
    "            val_check_steps=50,\n",
    "            scaler_type='revin',\n",
    "            learning_rate=1e-3)\n",
    "nf = NeuralForecast(models=[model], freq='MS')\n",
    "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
    "\n",
    "# BaseRecurrent model: test RevIN's dynamic dimensionality with a Huber loss and one hist_exog variable\n",
    "model = RNN(h=12,\n",
    "            input_size=24,\n",
    "            loss=HuberLoss(),\n",
    "            hist_exog_list=['trend'],\n",
    "            max_steps=1,\n",
    "            early_stop_patience_steps=10,\n",
    "            val_check_steps=50,\n",
    "            scaler_type='revin',\n",
    "            learning_rate=1e-3)\n",
    "nf = NeuralForecast(models=[model], freq='MS')\n",
    "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
